perf: optimize rejection sampling triton kernel #25791
base: main
Conversation
Code Review
This pull request optimizes the sample_recovered_tokens_kernel Triton kernel by unrolling the computation over the vocabulary size. This is a good performance optimization for large vocabularies. The implementation of the online maximum calculation is correct. However, I've identified a performance issue where a value is redundantly loaded within a loop. My review includes a suggestion to fix this.
The diff replaces the previous single-pass selection,

    recovered_id = tl.argmax(prob / q, axis=-1)
    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

with a blocked loop over the vocabulary:

    for off in range(0, vocab_size, BLOCK_SIZE):
        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=((vocab_offset < vocab_size) &
                      (vocab_offset != draft_token_id)),
                other=0,
            )
        else:
            draft_prob = tl.load(
                draft_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            target_prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            prob = tl.maximum(target_prob - draft_prob, 0)
        q = tl.load(
            q_ptr + req_idx * vocab_size + vocab_offset,
            mask=vocab_offset < vocab_size,
            other=float("-inf"),
        )
        scores = prob / q
        local_val = tl.max(scores, axis=-1)
        local_idx = tl.argmax(scores, axis=-1) + off
        # update global max
        better = local_val > max_val
        max_val = tl.where(better, local_val, max_val)
        max_idx = tl.where(better, local_idx, max_idx)
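The review above notes that the online maximum calculation is correct; that property can be sanity-checked outside Triton. A minimal NumPy sketch (illustrative only, not vLLM code; blocked_argmax is a hypothetical helper) of the same blocked running max/argmax:

    import numpy as np

    def blocked_argmax(scores: np.ndarray, block_size: int) -> int:
        # Running ("online") max/argmax over fixed-size blocks, mirroring the
        # kernel's max_val / max_idx update via tl.where(better, ...).
        max_val, max_idx = float("-inf"), -1
        for off in range(0, scores.size, block_size):
            block = scores[off:off + block_size]
            local_idx = int(np.argmax(block))
            local_val = block[local_idx]
            if local_val > max_val:
                max_val, max_idx = local_val, off + local_idx
        return max_idx

    x = np.random.rand(151_936)  # a vocabulary-sized score vector
    assert blocked_argmax(x, 8192) == int(np.argmax(x))

Ties resolve to the earliest index in both versions because the running update uses a strict greater-than comparison.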
For performance, draft_token_id should be loaded only once before the loop since its value doesn't change across iterations. Loading it inside the loop results in redundant reads from global memory, which can negatively impact kernel performance.
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
    for off in range(0, vocab_size, BLOCK_SIZE):
        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
        if NO_DRAFT_PROBS:
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=((vocab_offset < vocab_size) &
                      (vocab_offset != draft_token_id)),
                other=0,
            )
        else:
            draft_prob = tl.load(
                draft_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            target_prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            prob = tl.maximum(target_prob - draft_prob, 0)
        q = tl.load(
            q_ptr + req_idx * vocab_size + vocab_offset,
            mask=vocab_offset < vocab_size,
            other=float("-inf"),
        )
        scores = prob / q
        local_val = tl.max(scores, axis=-1)
        local_idx = tl.argmax(scores, axis=-1) + off
        # update global max
        better = local_val > max_val
        max_val = tl.where(better, local_val, max_val)
        max_idx = tl.where(better, local_idx, max_idx)
Purpose
Optimize the sample_recovered_tokens_kernel implementation by unrolling the CTA over the vocab_size dimension. Tuning num_warps could also be helpful for small batch sizes.
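For readers unfamiliar with this pattern, the sketch below is a generic, hypothetical Triton kernel (not the vLLM sample_recovered_tokens_kernel; all names, shapes, and the BLOCK_SIZE/num_warps values are made up for illustration). A single CTA walks a long row in BLOCK_SIZE chunks, and num_warps is supplied as a launch-time tuning knob:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def rowwise_sum_kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
        # One CTA (program) per row; walk the row in BLOCK_SIZE chunks instead
        # of requiring BLOCK_SIZE to cover the whole row at once.
        row = tl.program_id(0)
        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        for off in range(0, n_cols, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            vals = tl.load(x_ptr + row * n_cols + cols,
                           mask=cols < n_cols, other=0.0)
            acc += vals
        tl.store(out_ptr + row, tl.sum(acc, axis=0))

    n_rows, n_cols = 4, 152_064  # vocab-sized rows; values chosen arbitrarily
    x = torch.randn(n_rows, n_cols, device="cuda")
    out = torch.empty(n_rows, device="cuda", dtype=torch.float32)
    # num_warps is a per-launch knob; 8 here is a guess, not a tuned value.
    rowwise_sum_kernel[(n_rows,)](x, out, n_cols, BLOCK_SIZE=8192, num_warps=8)
    print(torch.allclose(out, x.sum(dim=-1), rtol=1e-3, atol=1e-1))

The PR applies the same structure to the recovered-token argmax, so a single block no longer has to span the entire vocabulary.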
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
(Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.